{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Codegen 模块\n",
    "\n",
    "codegen 模块和 MSCGraph 一起使用，用于将 MSCGraph 转译成 Python 脚本或 C++ 脚本。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/media/pc/data/lxw/ai/tvm-book/doc/read/msc\n"
     ]
    }
   ],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# Move the working directory one level up; relative paths below resolve from there.\n",
    "%cd ..\n",
    "\n",
    "# Scratch directory for this notebook; creating it again on re-run is a no-op.\n",
    "temp_dir = Path(\".temp\")\n",
    "temp_dir.mkdir(exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(inp_0: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">4</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">6</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">4</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            conv2d: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">6</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">4</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>conv2d(inp_0, metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">0</span>], strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], padding<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>], dilation<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], groups<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>, data_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, kernel_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;OIHW&quot;</span>, out_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)\n",
       "            relu: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">6</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">4</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>relu(conv2d)\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(relu)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> relu\n",
       "\n",
       "<span style=\"color: #007979; font-style: italic\"># Metadata omitted. Use show_meta=True in script() method to show it.</span>\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/codegen/codegen.py:74: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  state_dict = torch.load(folder.relpath(graph.name + \".pth\"))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "main(\n",
       "  (conv2d): Conv2d(3, 6, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "  (relu): ReLU()\n",
       ")"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen\n",
    "from tvm.contrib.msc.framework.torch import codegen as torch_codegen\n",
    "from tvm.contrib.msc.framework.torch.frontend import translate as torch_translate\n",
    "from tvm.contrib.msc.core import utils as msc_utils\n",
    "from tvm.contrib.msc.core.frontend import translate\n",
    "from graph.model import get_model\n",
    "\n",
    "# Input description: one (shape, dtype) pair per model input.\n",
    "input_info = [((1, 3, 4, 4), \"float32\")]\n",
    "relax_mod, torch_fx_model = get_model(input_info)\n",
    "\n",
    "# Relax module -> (graph, weights) -> Relax module regenerated by the TVM codegen.\n",
    "# Each stage keeps its own name so nothing from this pipeline is overwritten below.\n",
    "relax_graph, relax_weights = translate.from_relax(relax_mod)\n",
    "tvm_build_folder = msc_utils.msc_dir(f\"{temp_dir}/tvm_test\")\n",
    "codegen_mod = tvm_codegen.to_relax(relax_graph, relax_weights, build_folder=tvm_build_folder)\n",
    "codegen_mod.show()\n",
    "\n",
    "# Torch FX model -> (graph, weights) via Relax -> torch module from the Torch codegen.\n",
    "torch_build_folder = msc_utils.msc_dir(f\"{temp_dir}/torch_test\")\n",
    "torch_graph, torch_weights = torch_translate.from_torch(torch_fx_model, input_info, via_relax=True)\n",
    "model = torch_codegen.to_torch(torch_graph, torch_weights, build_folder=torch_build_folder)\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "````{dropdown} 产生的代码片段：\n",
    "```{code-block} python\n",
    ":caption: .temp/tvm_test/main.py\n",
    "import os\n",
    "import numpy as np\n",
    "from typing import List, Dict, Any\n",
    "import tvm\n",
    "from tvm.contrib.msc.core import utils as msc_utils\n",
    "from tvm import relax\n",
    "\n",
    "# Define the helpers\n",
    "def load_data(name: str, shape: List[int], dtype: str) -> np.ndarray:\n",
    "  path = os.path.join(\"baseline\", name + \".bin\")\n",
    "  if os.path.isfile(path):\n",
    "    data = np.fromfile(path, dtype=dtype).reshape(shape)\n",
    "  else:\n",
    "    data = np.ones((shape)).astype(dtype)\n",
    "  return data\n",
    "\n",
    "\n",
    "# Define the graph\n",
    "def main(res_0: relax.Var) -> tvm.IRModule:\n",
    "  inputs = [res_0]\n",
    "  # Define the weights\n",
    "  weight_1 = relax.Var(\"const\", relax.TensorStructInfo([6, 3, 1, 1], \"float32\"))\n",
    "  inputs.append(weight_1)\n",
    "  # Define the module\n",
    "  block_builder = relax.BlockBuilder()\n",
    "  with block_builder.function(name=\"main\", params=inputs.copy()):\n",
    "    # conv2d(nn.conv2d): <res_0> -> <res_1>\n",
    "    res_1 = relax.op.nn.conv2d(res_0, weight_1, strides=[1, 1], padding=[0, 0, 0, 0], dilation=[1, 1], groups=1, data_layout=\"NCHW\", kernel_layout=\"OIHW\", out_layout=\"NCHW\", out_dtype=\"float32\")\n",
    "    res_1 = block_builder.emit(res_1, name_hint=\"conv2d\")\n",
    "    # relu(nn.relu): <res_1> -> <res_2>\n",
    "    res_2 = relax.op.nn.relu(res_1)\n",
    "    res_2 = block_builder.emit(res_2, name_hint=\"relu\")\n",
    "    # Emit the outputs\n",
    "    block_builder.emit_func_output(res_2)\n",
    "  mod = block_builder.finalize()\n",
    "  return mod\n",
    "\n",
    "\n",
    "# Define the test\n",
    "if __name__ == \"__main__\":\n",
    "  # Prepare test datas\n",
    "  inputs = {}\n",
    "  golden = {}\n",
    "  inputs[\"inp_0\"] = load_data(\"inp_0\", [1, 3, 4, 4], \"float32\")\n",
    "  golden[\"relu\"] = load_data(\"relu\", [1, 6, 4, 4], \"float32\")\n",
    "  # Build and inference the graph\n",
    "  res_0 = relax.Var(\"inp_0\", relax.TensorStructInfo([1, 3, 4, 4], \"float32\"))\n",
    "  # Build Module\n",
    "  mod = main(res_0)\n",
    "  # Load weights\n",
    "  with open(\"main_params.bin\", \"rb\") as f:\n",
    "    params = tvm.runtime.load_param_dict(f.read())\n",
    "  bind_params = tvm.relax.transform.BindParams(\"main\", params)\n",
    "  mod = bind_params(mod)\n",
    "  target = tvm.target.Target(\"llvm\")\n",
    "  mod = tvm.relax.transform.LegalizeOps()(mod)\n",
    "  with tvm.transform.PassContext(opt_level=3):\n",
    "    ex = relax.build(mod, target)\n",
    "    vm = relax.VirtualMachine(ex, tvm.cpu())\n",
    "  f_main = vm[\"main\"]\n",
    "  outputs = f_main(inputs[\"inp_0\"])\n",
    "  msc_utils.compare_arrays(golden, outputs, verbose=\"detail\")\n",
    "```\n",
    "\n",
    "```{code-block} python\n",
    ":caption: .temp/torch_test/main.py\n",
    "import os\n",
    "import numpy as np\n",
    "from typing import List, Dict, Any\n",
    "import tvm\n",
    "from tvm.contrib.msc.core import utils as msc_utils\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional\n",
    "\n",
    "# Define the helpers\n",
    "def load_data(name: str, shape: List[int], dtype: str) -> np.ndarray:\n",
    "  path = os.path.join(\"baseline\", name + \".bin\")\n",
    "  if os.path.isfile(path):\n",
    "    data = np.fromfile(path, dtype=dtype).reshape(shape)\n",
    "  else:\n",
    "    data = np.ones((shape)).astype(dtype)\n",
    "  return data\n",
    "\n",
    "\n",
    "# Define the graph\n",
    "class main(torch.nn.Module):\n",
    "  def __init__(self: torch.nn.Module) -> torch.nn.Module:\n",
    "    super(main, self).__init__()\n",
    "    # conv2d(nn.conv2d): <res_0> -> <res_1>\n",
    "    self.conv2d = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=[1, 1], stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1, bias=False)\n",
    "    # relu(nn.relu): <res_1> -> <res_2>\n",
    "    self.relu = nn.ReLU()\n",
    "\n",
    "  def forward(self: torch.nn.Module, res_0: torch.Tensor) -> List[torch.Tensor]:\n",
    "    # conv2d(nn.conv2d): <res_0> -> <res_1>\n",
    "    res_1 = self.conv2d(res_0)\n",
    "    # relu(nn.relu): <res_1> -> <res_2>\n",
    "    res_2 = self.relu(res_1)\n",
    "    outputs = res_2\n",
    "    return outputs\n",
    "\n",
    "\n",
    "# Define the test\n",
    "if __name__ == \"__main__\":\n",
    "  # Prepare test datas\n",
    "  inputs = {}\n",
    "  golden = {}\n",
    "  inputs[\"inp_0\"] = load_data(\"inp_0\", [1, 3, 4, 4], \"float32\")\n",
    "  golden[\"relu\"] = load_data(\"relu\", [1, 6, 4, 4], \"float32\")\n",
    "  # Build and inference the graph\n",
    "  # Build Model\n",
    "  model = main()\n",
    "  # Load weights\n",
    "  weights = torch.load(\"main.pth\")\n",
    "  model.load_state_dict(weights)\n",
    "  res_0 = torch.from_numpy(inputs[\"inp_0\"])\n",
    "  outputs = model(res_0)\n",
    "  msc_utils.compare_arrays(golden, outputs, verbose=\"detail\")\n",
    "```\n",
    "````"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
